from time import sleep
import torch
from torch_ac.belief import threshold_rm_beliefs

import utils
from model import ACModel
from recurrent_model import RecurrentACModel
from detector_model import SimpleDetectorModel, RecurrentDectectorModel, PerfectDetector

class Agent:
    """An agent.

    It is able:
    - to choose an action given an observation,
    - to analyze the feedback (i.e. reward and done state) of its action."""

    # RM-update algorithms whose detector is built against obs_space['events'].
    _EVENT_ALGOS = ("event_threshold", "independent_belief")
    # RM-update algorithms whose detector is built against obs_space['rm_state'].
    _RM_STATE_ALGOS = ("rm_detector", "rm_threshold")
    # Algorithms that run the detector every step and feed its (possibly
    # post-processed) output to the policy as ``rm_belief``.
    # "perfect_rm" is intentionally absent.
    _BELIEF_ALGOS = _RM_STATE_ALGOS + _EVENT_ALGOS

    def __init__(self, env, obs_space, action_space, model_dir,
                 rm_update_algo, use_mem=False, use_mem_detector=False,
                 no_rm=False, dumb_ac=False, device=None, argmax=False):
        """Build the policy and detector models and load their trained weights.

        This function should be sync'd with the initialization in train_agent.py

        Args:
            env: environment; used to build the observation preprocessor and
                the actor-critic model, and read/updated for RM beliefs in
                ``get_action``.
            obs_space: ignored -- the observation space is re-derived from
                ``env`` via ``utils.get_obss_preprocessor``. Kept only for
                interface compatibility.
            action_space: ignored -- ``env.action_space`` is used instead.
                Kept only for interface compatibility.
            model_dir: directory the actor-critic and detector weights are
                loaded from.
            rm_update_algo: one of "rm_detector", "rm_threshold",
                "event_threshold", "independent_belief" or "perfect_rm".
            use_mem: use the recurrent actor-critic model.
            use_mem_detector: use the recurrent detector model.
            no_rm: forwarded to the actor-critic constructor.
            dumb_ac: forwarded to the actor-critic constructor.
            device: torch device for the models and recurrent memories
                (None uses torch's default device).
            argmax: act greedily instead of sampling from the policy.

        Raises:
            NotImplementedError: if ``rm_update_algo`` is not recognised.
        """
        print(model_dir)
        self.env = env
        self.rm_update_algo = rm_update_algo
        self.use_mem = use_mem
        self.no_rm = no_rm
        self.use_rm_belief = rm_update_algo in self._BELIEF_ALGOS
        self.device = device
        self.argmax = argmax
        self.num_envs = 1

        # The caller-supplied obs_space is deliberately replaced here.
        obs_space, self.preprocess_obss = utils.get_obss_preprocessor(env)

        self.detectormodel = self._build_detector(obs_space, rm_update_algo, use_mem_detector)
        if use_mem_detector:
            self.detector_memories = torch.zeros(1, self.detectormodel.memory_size, device=device)

        if use_mem:
            self.acmodel = RecurrentACModel(env, obs_space, env.action_space, dumb_ac, no_rm)
            self.memories = torch.zeros(1, self.acmodel.memory_size, device=device)
        else:
            self.acmodel = ACModel(env, obs_space, env.action_space, dumb_ac, no_rm)

        self.acmodel.load_state_dict(utils.get_model_state(model_dir))
        self.acmodel.to(self.device)
        self.acmodel.eval()

        self.detectormodel.load_state_dict(utils.get_detector_model_state(model_dir))
        self.detectormodel.to(self.device)
        self.detectormodel.eval()

    @classmethod
    def _build_detector(cls, obs_space, rm_update_algo, recurrent):
        """Instantiate the (untrained) detector model for ``rm_update_algo``.

        Raises:
            NotImplementedError: if ``rm_update_algo`` is not recognised.
        """
        if rm_update_algo == "perfect_rm":
            return PerfectDetector(obs_space)
        if rm_update_algo in cls._EVENT_ALGOS:
            target_spec = obs_space['events']
        elif rm_update_algo in cls._RM_STATE_ALGOS:
            target_spec = obs_space['rm_state']
        else:
            raise NotImplementedError()
        detector_cls = RecurrentDectectorModel if recurrent else SimpleDetectorModel
        return detector_cls(obs_space, target_spec)

    def get_action(self, obs, display_rm_belief_as_mission=False):
        """Choose an action for a single observation.

        This function should be sync'd with collect_experiences() in
        src/torch_ac/algos/base.py

        Args:
            obs: one raw observation from ``self.env``.
            display_rm_belief_as_mission: if true (and a belief-based
                ``rm_update_algo`` is in use), write the current belief into
                ``self.env.unwrapped.env.mission`` for visualisation.

        Returns:
            The chosen action as a NumPy scalar (in both sampling and argmax
            mode).
        """
        preprocessed_obs = self.preprocess_obss([obs], device=self.device)

        with torch.no_grad():
            if self.use_rm_belief:
                detector_belief = self._detector_belief(preprocessed_obs)
                preprocessed_obs.rm_belief = detector_belief
                if display_rm_belief_as_mission:
                    self._show_belief_as_mission(detector_belief)

            if self.acmodel.recurrent:
                dist, value, self.memories = self.acmodel(preprocessed_obs, self.memories, use_rm_belief=self.use_rm_belief)
            else:
                dist, value = self.acmodel(preprocessed_obs, use_rm_belief=self.use_rm_belief)

        if self.argmax:
            # argmax(dim=1) has shape (1,) like dist.sample(), so both modes
            # return a scalar below (max(1, keepdim=True) gave shape (1, 1)).
            actions = dist.probs.argmax(dim=1)
        else:
            actions = dist.sample()

        return actions.cpu().numpy()[0]

    def _detector_belief(self, preprocessed_obs):
        """Run the detector and post-process its output per ``rm_update_algo``."""
        if self.detectormodel.recurrent:
            belief, self.detector_memories = self.detectormodel(preprocessed_obs, self.detector_memories)
        else:
            belief = self.detectormodel(preprocessed_obs)

        if self.rm_update_algo == "rm_threshold":
            belief = threshold_rm_beliefs(belief)
        elif self.rm_update_algo == "event_threshold":
            # The env receives a boolean mask of positive detector outputs.
            # diff from collect_experience(): the single env is updated directly.
            belief = self.env.update_rm_beliefs((belief > 0).cpu().numpy()[0, :])
            belief = torch.tensor(belief, device=self.device, dtype=torch.float).unsqueeze(0)
        elif self.rm_update_algo == "independent_belief":
            # The env receives the raw detector output.
            # diff from collect_experience(): the single env is updated directly.
            belief = self.env.update_rm_beliefs(belief.cpu().numpy()[0, :])
            belief = torch.tensor(belief, device=self.device, dtype=torch.float).unsqueeze(0)
        return belief

    def _show_belief_as_mission(self, detector_belief):
        """Write the current RM belief into the wrapped env's ``mission`` field."""
        if self.rm_update_algo in self._RM_STATE_ALGOS:
            # Explicit dim=1 (the batch has a single row) avoids the
            # deprecated implicit-dim softmax.
            probs = torch.nn.functional.softmax(detector_belief, dim=1)[0]
            self.env.unwrapped.env.mission = [round(p.item(), 2) for p in probs]
        elif self.rm_update_algo == "event_threshold":
            self.env.unwrapped.env.mission = [0 if i != self.env.belief_u_id else 1 for i in range(self.env.num_rm_states)]
        elif self.rm_update_algo == "independent_belief":
            self.env.unwrapped.env.mission = [round(p, 2) for p in self.env.belief_u_dist]

    def analyze_feedback(self, done):
        """Zero the recurrent memories when the episode has ended.

        Args:
            done: whether the last step ended the episode.
        """
        # Build the mask on the same device as the memories; a (1, 1) CPU
        # tensor cannot multiply a CUDA tensor in place.
        masks = 1 - torch.tensor([done], dtype=torch.float, device=self.device).unsqueeze(1)

        if self.acmodel.recurrent:
            self.memories *= masks
        if self.detectormodel.recurrent:
            self.detector_memories *= masks
